# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

from tutel import system

import os
import time
import json
import random
import argparse
import datetime
import numpy as np
from functools import partial
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import NativeScalerWithGradNormCount, reduce_tensor
from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_resume_helper, hook_scale_grad

from fast_gradient_method import fast_gradient_method
from projected_gradient_descent import projected_gradient_descent
from spsa import spsa
from autoattack.autoattack import AutoAttack

from fvcore.nn import FlopCountAnalysis

# (major, minor) of the installed PyTorch, parsed numerically: a plain string
# comparison would rank "1.10.0" below "1.8.0". A local-version suffix such as
# "+cu118" only affects the patch component, which is not parsed.
_TORCH_MAJOR_MINOR = tuple(int(part) for part in torch.__version__.split('.')[:2])
# Raise instead of `assert` so the check survives `python -O`.
if _TORCH_MAJOR_MINOR < (1, 8):
    raise RuntimeError("DDP-based MoE requires Pytorch >= 1.8.0")

# pytorch major version (1.x or 2.x)
PYTORCH_MAJOR_VERSION = _TORCH_MAJOR_MINOR[0]


def parse_option():
    """Parse the command line and build the experiment config from it.

    Unknown arguments are tolerated (``parse_known_args``) and discarded. On
    PyTorch 1.x a ``--local_rank`` argument is required; on later versions it
    is not declared at all.

    Returns:
        ``(args, config)``: the parsed ``argparse.Namespace`` and the config
        returned by ``get_config(args)``.
    """
    ap = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    add = ap.add_argument

    add('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    add("--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+')

    # Shortcuts that override individual config entries.
    add('--batch-size', type=int, help="batch size for single GPU")
    add('--data-path', type=str, help='path to dataset')
    add('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    cache_mode_help = ('no: no cache, '
                       'full: cache all data, '
                       'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    add('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], help=cache_mode_help)
    add('--pretrained', help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    add('--resume', help='resume from checkpoint')
    add('--accumulation-steps', type=int, help="gradient accumulation steps")
    add('--use-checkpoint', action='store_true', help="whether to use gradient checkpointing to save memory")
    add('--disable_amp', action='store_true', help='Disable pytorch amp')
    add('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    add('--output', default='output', type=str, metavar='PATH',
        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    add('--tag', help='tag of experiment')
    add('--eval', action='store_true', help='Perform evaluation only')
    add('--throughput', action='store_true', help='Test throughput only')

    # Distributed training: on PyTorch >= 2.0 use `os.environ['LOCAL_RANK']`
    # instead of this flag
    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility).
    if PYTORCH_MAJOR_VERSION == 1:
        add("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # Weights & Biases experiment tracking.
    add('--use_wandb', action='store_true', help='use wandb.')
    add('--project_name', default=None, type=str)
    add('--job_name', type=str, default=None, help='job name for wandb.')

    # Debugging, MoE diagnostics and adversarial evaluation.
    add('--debug', action='store_true')
    add('--show-gate-w-stats', action='store_true', help='log the cosa weights')
    add('--attack', type=str, default='none',
        choices=['pgd', 'fgm', 'spsa', 'auto', 'auto-individual', 'square', 'none'],
        help='adversarial attack to select from')
    add('--eps', type=int, default=1, help='perturbation budget in attack')
    add('--compute-router-stability', action='store_true')

    args, _unparsed = ap.parse_known_args()
    return args, get_config(args)


def main(config):
    """Build data, model, optimizer and scheduler, then evaluate or train.

    Relies on the module-level ``args`` and ``logger`` created under
    ``__main__`` (and ``wandb`` when ``args.use_wandb`` is set).

    Flow:
      * ``config.MODEL.RESUME`` set and ``args.attack != 'none'``: load the
        checkpoint, run one adversarial evaluation, log it and return.
      * ``config.MODEL.RESUME`` set (or ``config.MODEL.PRETRAINED`` without a
        resume): validate once; return if ``config.EVAL_MODE``.
      * ``config.THROUGHPUT_MODE``: measure throughput and return.
      * Otherwise train for epochs ``[config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)``,
        validating after every epoch and checkpointing every 5th and the last
        epoch (unless ``args.debug``), then save a ``'final'`` checkpoint.
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))

    # For Tutel MoE: parameters flagged ``skip_allreduce`` are registered with
    # the model to be skipped by all-reduce, and get a gradient hook
    # (hook_scale_grad) parameterised by the world size.
    world_size = dist.get_world_size()
    for name, param in model.named_parameters():
        if param.requires_grad and getattr(param, 'skip_allreduce', False) is True:
            model.add_param_to_skip_allreduce(name)
            param.register_hook(partial(hook_scale_grad, world_size))
            logger.info(f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {world_size} for grad")

    # Trainable-parameter counts. Parameters carrying ``skip_allreduce`` are
    # scaled by ``model.sharded_count`` ("single") and additionally by
    # ``model.global_experts`` ("whole").
    n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce')
                              else p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params single: {n_parameters_single}")
    n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce')
                             else p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params whole: {n_parameters_whole}")
    if hasattr(model, 'flops'):
        flops = model.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    model.cuda(config.LOCAL_RANK)
    model_without_ddp = model

    # The optimizer is built on the bare module, before DDP wrapping.
    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()

    # With gradient accumulation the scheduler advances once per optimizer
    # update, i.e. every ACCUMULATION_STEPS batches.
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
    else:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME and not args.debug:
        resume_file = auto_resume_helper(config.OUTPUT, config.TRAIN.MOE.SAVE_MASTER)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        if args.attack != 'none':
            # Adversarial evaluation only: never continues into training.
            config.defrost()
            config.EVAL_MODE = True
            config.freeze()

            _ = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
            acc1, acc5, loss = validate(config, data_loader_val, model, attack=args.attack, eps=args.eps,
                                        compute_router_stability=args.compute_router_stability)
            logger.info(f'---------- Evaluation on {args.attack} ---------')
            logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: Top1: {acc1:.2f}% | Top5: {acc5:.2f}%")
            return

        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                        loss_scaler)
        if (epoch % 5 == 0 or epoch == (config.TRAIN.EPOCHS - 1)) and not args.debug:
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger)

        if args.show_gate_w_stats:
            for i, basic_layer in enumerate(model_without_ddp.layers):
                if basic_layer.use_cosa:
                    for idx, block in enumerate(basic_layer.blocks):
                        # NOTE(review): index 2 is hard-coded instead of the layer
                        # index ``i``; confirm this is intended if more than one
                        # layer sets ``use_cosa``.
                        if idx in config.MODEL.SWIN_MOE.COSA_POSITIONS[2]:
                            Wmin, Wmax, Wmean, Wstd, clamp_roof, alpha = block.W
                            logger.info('-----------Gate Weight Statistics ------------')
                            logger.info(f'Min: {Wmin} | Max: {Wmax} | Mean: {Wmean} | std: {Wstd} | Clamp Roof: {clamp_roof} | Unif Mix Weight: {alpha}')

        acc1, acc5, loss = validate(config, data_loader_val, model)
        if args.use_wandb:
            wandb.log({'test_acc1':acc1, 'test_acc5':acc5, 'test_loss':loss, 'epoch':epoch})
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
    save_checkpoint(config, 'final', model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                    logger, zero_redundancy=True)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    """Train ``model`` for one epoch with optional AMP and gradient accumulation.

    Each batch's loss is ``criterion(outputs, targets) + l_aux`` (the model's
    auxiliary loss), divided by ``config.TRAIN.ACCUMULATION_STEPS``.
    ``loss_scaler`` is called for every batch and is asked to update the
    weights (``update_grad=True``) only on the last batch of each accumulation
    window. After that batch, gradients are cleared and the LR scheduler is
    advanced. Progress is logged through the module-level ``logger`` every
    ``config.PRINT_FREQ`` batches. The epoch's mean loss goes to wandb when
    ``args.use_wandb`` is set.
    """
    model.train()
    optimizer.zero_grad()

    accum_steps = config.TRAIN.ACCUMULATION_STEPS
    num_steps = len(data_loader)
    # timm marks second-order optimizers (e.g. AdaHessian) with ``is_second_order``;
    # they need the backward graph kept (create_graph=True).
    is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

    batch_time, loss_meter, loss_aux_meter, loss_cls_meter, norm_meter, scaler_meter = (
        AverageMeter() for _ in range(6))

    epoch_start = time.time()
    last_tick = time.time()
    for step, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs, l_aux = model(samples)

        l_cls = criterion(outputs, targets)
        loss = (l_cls + l_aux) / accum_steps

        is_update_step = (step + 1) % accum_steps == 0
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + step) // accum_steps)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        # Make the wall-clock batch time below reflect finished GPU work.
        torch.cuda.synchronize()

        n_samples = targets.size(0)
        loss_meter.update(loss.item(), n_samples)
        loss_cls_meter.update(l_cls.item(), n_samples)
        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), n_samples)
        # The scaler yields no norm on batches where it does not update.
        if grad_norm is not None:
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - last_tick)
        last_tick = time.time()

        if step % config.PRINT_FREQ != 0:
            continue
        first_group = optimizer.param_groups[0]
        lr, wd = first_group['lr'], first_group['weight_decay']
        memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
        etas = batch_time.avg * (num_steps - step)
        logger.info(
            f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{step}/{num_steps}]\t'
            f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
            f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
            f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
            f'loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t'
            f'loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t'
            f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
            f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
            f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - epoch_start
    if args.use_wandb:
        wandb.log({'train_loss':loss_meter.avg, 'epoch':epoch})
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


def _scale_attack_eps(attack, eps):
    """Convert the command-line perturbation budget to the scale ``attack`` uses."""
    if attack in ('fgm', 'pgd', 'square') or attack.startswith('auto'):
        return eps / 255
    if attack == 'spsa':
        return eps / 10
    return eps


def _build_autoattack(model, attack, eps, auto_verbose_dir):
    """Create the AutoAttack adversary for 'auto', 'auto-individual' or 'square'."""
    # Create the log *directory*, then point AutoAttack at a file inside it
    # (os.makedirs returns None, so it cannot supply the path itself).
    os.makedirs(auto_verbose_dir, exist_ok=True)
    log_path = os.path.join(auto_verbose_dir, 'indi.txt')
    if attack == 'square':
        adversary = AutoAttack(model, norm='Linf', eps=eps, version='custom', verbose=True, log_path=log_path)
        adversary.attacks_to_run = ['square']
        print(f'{attack} attack using perturbation budget {eps}')
    else:
        adversary = AutoAttack(model, norm='Linf', eps=eps, version='standard', verbose=True, log_path=log_path)
        print(f'Auto attack using perturbation budget {eps}')
    return adversary


def _perturb(attack, model, images, target, eps, adversary):
    """Return ``images`` perturbed by ``attack`` (unchanged for 'none')."""
    if attack == 'fgm':
        return fast_gradient_method(model, images, eps, np.inf)
    if attack == 'pgd':
        return projected_gradient_descent(model, images, eps, 0.15 * eps, 20, np.inf)
    if attack == 'spsa':
        return spsa(model, images, eps, 20)
    if attack in ('auto', 'square'):
        return adversary.run_standard_evaluation(images, target, bs=images.shape[0])
    return images


def _print_router_stability(model):
    """Print ``router_stability`` of every MoE block of the DDP-wrapped ``model``."""
    for i, basic_layer in enumerate(model.module.layers):
        for j, block in enumerate(basic_layer.blocks):
            if block.is_moe:
                print(f'basic layer {i} - block {j} | router stability: {block.router_stability}')


def _validate_auto_individual(config, data_loader, model, adversary):
    """Evaluate each standard AutoAttack component separately.

    Returns ``(mean top-1, mean top-5, 0.0)`` where the means are taken over
    the components; loss is not tracked here.
    """
    batch_time = AverageMeter()
    # component name -> (top-1 meter, top-5 meter)
    auto_indi_meters = {name: (AverageMeter(), AverageMeter())
                        for name in ('apgd-ce', 'apgd-t', 'fab-t', 'square')}

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        dict_images = adversary.run_standard_evaluation_individual(images, target, bs=images.shape[0])

        # Forward every component's adversarial batch, then score them.
        individual_outputs = {}
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            for name, adv_images in dict_images.items():
                individual_outputs[name], _ = model(adv_images)

        for name, output in individual_outputs.items():
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            acc1 = reduce_tensor(acc1)
            acc5 = reduce_tensor(acc5)

            acc1_meter, acc5_meter = auto_indi_meters[name]
            acc1_meter.update(acc1.item(), target.size(0))
            acc5_meter.update(acc5.item(), target.size(0))

            print(f'{name} | test_acc1: {acc1} , test_acc5: {acc5}')
            if args.use_wandb:
                wandb.log({f'{name}_test_acc1':acc1, f'{name}_test_acc5':acc5})

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
            for name, (acc1_meter, acc5_meter) in auto_indi_meters.items():
                logger.info(f'-- {name} -- \t'
                            f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                            f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t')

    top1_avg, top5_avg = [], []
    for name, (acc1_meter, acc5_meter) in auto_indi_meters.items():
        logger.info(f'{name} | * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
        top1_avg.append(acc1_meter.avg)
        top5_avg.append(acc5_meter.avg)
    # Meter averages are Python floats, so average them directly.
    acc1 = sum(top1_avg) / len(top1_avg)
    acc5 = sum(top5_avg) / len(top5_avg)
    return acc1, acc5, 0.0  # loss is unimportant here, return a 0


# Deliberately NOT decorated with @torch.no_grad(): the adversarial attacks need
# gradients with respect to the input images. Autograd is disabled locally
# around the evaluation forward passes instead.
def validate(config, data_loader, model, attack='none', eps=1, auto_verbose_dir='verbose_log',
             compute_router_stability=False):
    """Evaluate ``model`` on ``data_loader``, optionally on adversarial inputs.

    Uses the module-level ``logger`` and ``args`` (and ``wandb`` when
    ``args.use_wandb`` is set).

    Args:
        config: experiment config; reads ``AMP_ENABLE`` and ``PRINT_FREQ``.
        data_loader: yields ``(images, target)`` batches.
        model: DDP-wrapped model whose forward returns ``(logits, l_aux)``.
        attack: 'none', 'fgm', 'pgd', 'spsa', 'auto', 'auto-individual' or 'square'.
        eps: perturbation budget as given on the command line. It is divided by
            255 for fgm/pgd/auto*/square and by 10 for spsa.
        auto_verbose_dir: directory that receives AutoAttack's ``indi.txt`` log.
        compute_router_stability: if true, print each MoE block's
            ``router_stability`` after every batch and once at the end.

    Returns:
        ``(acc1, acc5, loss)``. For 'auto-individual' the accuracies are means
        over the AutoAttack components and ``loss`` is ``0.0``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    eps = _scale_attack_eps(attack, eps)
    adversary = None
    if attack.startswith('auto') or attack == 'square':
        adversary = _build_autoattack(model, attack, eps, auto_verbose_dir)

    if attack == 'auto-individual':
        return _validate_auto_individual(config, data_loader, model, adversary)

    batch_time = AverageMeter()
    loss_cls_meter = AverageMeter()
    loss_aux_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        images = _perturb(attack, model, images, target, eps, adversary)

        # compute output; no autograd graph is needed past this point
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output, l_aux = model(images)
        if compute_router_stability:
            _print_router_stability(model)

        # measure accuracy and record loss
        l_cls = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)

        if attack != 'none' and args.use_wandb:
            wandb.log({'test_acc1':acc1,'test_acc5':acc5})

        n_samples = target.size(0)
        loss_cls_meter.update(l_cls.item(), n_samples)
        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), n_samples)
        acc1_meter.update(acc1.item(), n_samples)
        acc5_meter.update(acc5.item(), n_samples)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t'
                f'Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    if compute_router_stability:
        _print_router_stability(model)
    return acc1_meter.avg, acc5_meter.avg, loss_cls_meter.avg


@torch.no_grad()
def throughput(data_loader, model, logger, *, warmup_iters=50, timed_iters=30):
    """Log inference throughput (images/s) measured on the first batch of ``data_loader``.

    The batch is moved to the GPU and run ``warmup_iters`` times untimed. It is
    then run ``timed_iters`` times between CUDA synchronizations. Throughput is
    ``timed_iters * batch_size / elapsed_seconds``. Nothing is measured or
    logged if ``data_loader`` is empty. Returns ``None``.
    """
    model.eval()

    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        return
    images, _ = first_batch
    images = images.cuda(non_blocking=True)
    batch_size = images.shape[0]

    for _ in range(warmup_iters):
        model(images)
    # Synchronize so queued warm-up kernels are not counted in the timing.
    torch.cuda.synchronize()
    logger.info(f"throughput averaged with {timed_iters} times")
    tic1 = time.time()
    for _ in range(timed_iters):
        model(images)
    torch.cuda.synchronize()
    tic2 = time.time()
    logger.info(f"batch_size {batch_size} throughput {timed_iters * batch_size / (tic2 - tic1)}")


if __name__ == '__main__':
    args, config = parse_option()
    # wandb credentials must be supplied externally (the WANDB_API_KEY
    # environment variable or a prior `wandb login`); API keys must never be
    # hard-coded in source files.
    if args.use_wandb:
        import wandb
        wandb.init(project=args.project_name)
        wandb.run.name = args.job_name
        wandb.config.update(config)

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    # torchrun-style launchers export RANK / WORLD_SIZE; otherwise pass -1 and
    # let the env:// init method resolve them.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    # Per-rank seed so ranks do not draw identical random streams.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())
    logger.info(json.dumps(vars(args)))

    main(config)
